#pragma once

#include <iostream>
#include <assert.h>
#include <ctime>
using namespace std;

/// Colour carried by every red-black tree node.
/// Newly linked nodes are coloured RED by RBTree::Insert; the root is
/// always recoloured BLACK at the end of an insertion.
enum Colour
{
    RED = 0,
    BLACK = 1,
};

template <class K, class V>
struct RBTreeNode
{
    // 这里更新控制平衡也要加入parent指针
    pair<K, V> _kv;
    RBTreeNode<K, V> *_left;
    RBTreeNode<K, V> *_right;
    RBTreeNode<K, V> *_parent;
    Colour _col;

    RBTreeNode(const pair<K, V> &kv)
        : _kv(kv), _left(nullptr), _right(nullptr), _parent(nullptr)
    {
    }
};

template <class K, class V>
class RBTree
{
    typedef RBTreeNode<K, V> Node;

public:
    bool Insert(const pair<K, V> &kv)
    {
        if (_root == nullptr)
        {
            _root = new Node(kv);
            _root->_col = BLACK;

            return true;
        }

        Node *parent = nullptr;
        Node *cur = _root;
        while (cur)
        {
            if (cur->_kv.first < kv.first)
            {
                parent = cur;
                cur = cur->_right;
            }
            else if (cur->_kv.first > kv.first)
            {
                parent = cur;
                cur = cur->_left;
            }
            else
            {
                return false;
            }
        }

        cur = new Node(kv);
        cur->_col = RED;
        if (parent->_kv.first < kv.first)
        {
            parent->_right = cur;
        }
        else
        {
            parent->_left = cur;
        }
        // 链接父亲
        cur->_parent = parent;

        // 父亲是红色，出现连续的红色节点，需要处理
        while (parent && parent->_col == RED)
        {
            Node *grandfather = parent->_parent;
            if (parent == grandfather->_left)
            {
                //   g
                // p   u
                Node *uncle = grandfather->_right;
                if (uncle && uncle->_col == RED)
                {
                    // 变色
                    parent->_col = uncle->_col = BLACK;
                    grandfather->_col = RED;

                    // 继续往上处理
                    cur = grandfather;
                    parent = cur->_parent;
                }
                else
                {
                    if (cur == parent->_left)
                    {
                        //     g
                        //   p    u
                        // c
                        RotateR(grandfather);
                        parent->_col = BLACK;
                        grandfather->_col = RED;
                    }
                    else
                    {
                        //      g
                        //   p    u
                        //     c
                        RotateL(parent);
                        RotateR(grandfather);
                        cur->_col = BLACK;
                        grandfather->_col = RED;
                    }

                    break;
                }
            }
            else
            {
                //   g
                // u   p
                Node *uncle = grandfather->_left;
                // 叔叔存在且为红，-》变色即可
                if (uncle && uncle->_col == RED)
                {
                    parent->_col = uncle->_col = BLACK;
                    grandfather->_col = RED;

                    // 继续往上处理
                    cur = grandfather;
                    parent = cur->_parent;
                }
                else // 叔叔不存在，或者存在且为黑
                {
                    // 情况二：叔叔不存在或者存在且为黑
                    // 旋转+变色
                    //   g
                    // u   p
                    //       c
                    if (cur == parent->_right)
                    {
                        RotateL(grandfather);
                        parent->_col = BLACK;
                        grandfather->_col = RED;
                    }
                    else
                    {
                        RotateR(parent);
                        RotateL(grandfather);
                        cur->_col = BLACK;
                        grandfather->_col = RED;
                    }

                    break;
                }
            }
        }

        _root->_col = BLACK;

        return true;
    }

    void RotateR(Node *parent)
    {
        Node *subL = parent->_left;
        Node *subLR = subL->_right;

        parent->_left = subLR;
        if (subLR)
            subLR->_parent = parent;

        Node *pParent = parent->_parent;

        subL->_right = parent;
        parent->_parent = subL;

        if (parent == _root)
        {
            _root = subL;
            subL->_parent = nullptr;
        }
        else
        {
            if (pParent->_left == parent)
            {
                pParent->_left = subL;
            }
            else
            {
                pParent->_right = subL;
            }

            subL->_parent = pParent;
        }
    }

    void RotateL(Node *parent)
    {
        Node *subR = parent->_right;
        Node *subRL = subR->_left;
        parent->_right = subRL;
        if (subRL)
            subRL->_parent = parent;

        Node *parentParent = parent->_parent;
        subR->_left = parent;
        parent->_parent = subR;
        if (parentParent == nullptr)
        {
            _root = subR;
            subR->_parent = nullptr;
        }
        else
        {
            if (parent == parentParent->_left)
            {
                parentParent->_left = subR;
            }
            else
            {
                parentParent->_right = subR;
            }
            subR->_parent = parentParent;
        }
    }

    void InOrder()
    {
        _InOrder(_root);
        cout << endl;
    }

    int Height()
    {
        return _Height(_root);
    }

    int Size()
    {
        return _Size(_root);
    }

    Node *Find(const K &key)
    {
        Node *cur = _root;
        while (cur)
        {
            if (cur->_kv.first < key)
            {
                cur = cur->_right;
            }
            else if (cur->_kv.first > key)
            {
                cur = cur->_left;
            }
            else
            {
                return cur;
            }
        }

        return nullptr;
    }

    bool IsBalance()
    {
        if (_root == nullptr)
            return true;

        if (_root->_col == RED)
            return false;

        // 参考值
        int refNum = 0;
        Node *cur = _root;
        while (cur)
        {
            if (cur->_col == BLACK)
            {
                ++refNum;
            }
            cur = cur->_left;
        }

        return Check(_root, 0, refNum);
    }

private:
    bool Check(Node *root, int blackNum, const int refNum)
    {
        if (root == nullptr)
        {
            // 前序遍历走到空时，意味着一条路径走完了
            // cout << blackNum << endl;
            if (refNum != blackNum)
            {
                cout << "存在黑色结点的数量不相等的路径" << endl;
                return false;
            }
            return true;
        }

        // 检查孩子不太方便，因为孩子有两个，且不一定存在，反过来检查父亲就方便多了
        if (root->_col == RED && root->_parent->_col == RED)
        {
            cout << root->_kv.first << "存在连续的红色结点" << endl;
            return false;
        }

        if (root->_col == BLACK)
        {
            blackNum++;
        }

        return Check(root->_left, blackNum, refNum) && Check(root->_right, blackNum, refNum);
    }

    void _InOrder(Node *root)
    {
        if (root == nullptr)
        {
            return;
        }

        _InOrder(root->_left);
        cout << root->_kv.first << ":" << root->_kv.second << endl;
        _InOrder(root->_right);
    }

    int _Height(Node *root)
    {
        if (root == nullptr)
            return 0;
        int leftHeight = _Height(root->_left);
        int rightHeight = _Height(root->_right);
        return leftHeight > rightHeight ? leftHeight + 1 : rightHeight + 1;
    }

    int _Size(Node *root)
    {
        if (root == nullptr)
            return 0;

        return _Size(root->_left) + _Size(root->_right) + 1;
    }

private:
    Node *_root = nullptr;
};
